import numpy as np
import torch
import copy
import warnings
from mmcv.cnn.bricks.registry import (ATTENTION,
                                      TRANSFORMER_LAYER,
                                      TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import TransformerLayerSequence
from mmcv.runner import force_fp32, auto_fp16
from mmcv.utils import TORCH_VERSION, digit_version
from mmcv.utils import ext_loader
from .custom_base_transformer_layer import MyCustomBaseTransformerLayer
# Load mmcv's compiled `_ext` extension through `ext_loader.load_ext`,
# requesting the `ms_deform_attn_backward` / `ms_deform_attn_forward` ops
# by name.
# NOTE(review): `ext_module` is not referenced in the code visible here --
# confirm it is still needed before removing it.
ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])


@TRANSFORMER_LAYER_SEQUENCE.register_module()
class AFSADecoder(TransformerLayerSequence):
    """Decoder whose layers use both self attention and cross attention.

    Before the layers run, the 3D reference points are projected into every
    camera image once (see `point_sampling`); the projected points and their
    visibility mask are then handed to every layer.

    Args:
        pc_range (Sequence[float] | None): Range used to de-normalize the
            reference points, laid out as
            ``[min_0, min_1, min_2, max_0, max_1, max_2]`` for the three
            point coordinates. Must be set before `forward` is called.
            Default: None.
        return_intermediate (bool): Whether to return the output of every
            layer (stacked) instead of only the last one. Default: False.
        dataset_type (str): Accepted by the constructor but not used by
            this class. Default: 'nuscenes'.
    """

    def __init__(self, *args, pc_range=None, return_intermediate=False,
                 dataset_type='nuscenes', **kwargs):
        super(AFSADecoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        self.pc_range = pc_range
        self.fp16_enabled = False

    @force_fp32(apply_to=('reference_points', 'img_metas'))
    def point_sampling(self, reference_points, pc_range, img_metas):
        """Project normalized 3D reference points into every camera image.

        Args:
            reference_points (Tensor): Normalized points with shape
                (bs, queue_len, num_query, num_pillar, 3). The input is not
                modified; a copy is de-normalized with `pc_range`.
            pc_range (Sequence[float]): Range laid out as
                ``[min_0, min_1, min_2, max_0, max_1, max_2]``.
            img_metas (Sequence[Mapping]): One mapping per sample. Every
                value of a mapping must provide a ``'lidar2img'`` entry
                (the per-camera 4x4 projection matrices). The image size
                used for normalization is read from
                ``img_metas[0][0]['img_shape'][0]`` as (height, width, ...),
                i.e. a single shape is applied to all cameras.

        Returns:
            tuple[Tensor, Tensor]:
                - Projected points divided by image width / height, with
                  shape (bs, queue_len, num_cam, num_query, num_pillar, 2).
                - Boolean mask with shape
                  (bs, queue_len, num_cam, num_query, num_pillar) that is
                  True where the point lies in front of the camera and
                  strictly inside the image.

        Raises:
            ValueError: If `pc_range` or `img_metas` is None.
        """
        # Fail early with a clear message instead of an opaque TypeError
        # from indexing / iterating None further below.
        if pc_range is None:
            raise ValueError(
                'pc_range must be set to project reference points')
        if img_metas is None:
            raise ValueError(
                'img_metas must be provided to project reference points')

        lidar2img = [frame_meta['lidar2img']
                     for sample_metas in img_metas
                     for frame_meta in sample_metas.values()]
        # (bs * queue_len, num_cam, 4, 4)
        lidar2img = reference_points.new_tensor(np.asarray(lidar2img))

        # Work on a copy so the caller's tensor is not changed in place.
        reference_points = reference_points.clone()
        for axis in range(3):
            low, high = pc_range[axis], pc_range[axis + 3]
            reference_points[..., axis:axis + 1] = (
                reference_points[..., axis:axis + 1] * (high - low) + low)

        # Homogeneous coordinates: (bs, queue_len, num_query, num_pillar, 4)
        reference_points = torch.cat(
            (reference_points, torch.ones_like(reference_points[..., :1])),
            -1)

        bs, queue_len, num_query, num_pillar = reference_points.size()[:4]
        num_cam = lidar2img.size(1)

        # (bs, queue_len, num_cam, num_query, num_pillar, 4, 1)
        reference_points = reference_points.view(
            bs, queue_len, 1, num_query, num_pillar, 4).repeat(
                1, 1, num_cam, 1, 1, 1).unsqueeze(-1)
        # (bs, queue_len, num_cam, num_query, num_pillar, 4, 4)
        lidar2img = lidar2img.view(
            bs, queue_len, num_cam, 1, 1, 4, 4).repeat(
                1, 1, 1, num_query, num_pillar, 1, 1)

        # (bs, queue_len, num_cam, num_query, num_pillar, 4)
        reference_points_cam = torch.matmul(
            lidar2img.to(torch.float32),
            reference_points.to(torch.float32)).squeeze(-1)

        eps = 1e-5
        depth = reference_points_cam[..., 2:3]
        # Only points in front of the camera can be visible.
        bev_mask = depth > eps
        # Clamp the divisor so points at / behind the camera plane do not
        # divide by zero or a negative depth.
        reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
            depth, torch.ones_like(depth) * eps)

        first_img_shape = img_metas[0][0]['img_shape'][0]
        reference_points_cam[..., 0] /= first_img_shape[1]
        reference_points_cam[..., 1] /= first_img_shape[0]

        # Comparisons against NaN evaluate to False, so the boolean mask is
        # already NaN-free and needs no nan_to_num pass.
        bev_mask = (bev_mask
                    & (reference_points_cam[..., 1:2] > 0.0)
                    & (reference_points_cam[..., 1:2] < 1.0)
                    & (reference_points_cam[..., 0:1] < 1.0)
                    & (reference_points_cam[..., 0:1] > 0.0))

        return reference_points_cam, bev_mask.squeeze(-1)

    def forward(self,
                query,
                query_pos,
                reference_points,
                img_feat_flatten,
                pts_feat_flatten,
                img_spatial_shapes=None,
                img_level_start_index=None,
                pts_spatial_shapes=None,
                pts_level_start_index=None,
                img_metas=None,
                key_padding_mask=None,
                attn_masks=None,
                reg_branch=None,
                **kwargs):
        """Run every decoder layer in sequence.

        Args:
            query (Tensor): Input query; the output of each layer becomes
                the query of the next one.
            query_pos (Tensor): Positional encoding passed to every layer.
            reference_points (Tensor): Normalized 3D reference points with
                shape (bs, queue_len, num_query, num_pillar, 3). Passed to
                every layer unchanged and also projected into the images.
            img_feat_flatten (Tensor): Image features passed to every layer.
            pts_feat_flatten (Tensor): Point-cloud features passed to every
                layer.
            img_spatial_shapes, img_level_start_index: Forwarded to every
                layer for the image features. Default: None.
            pts_spatial_shapes, pts_level_start_index: Forwarded to every
                layer for the point-cloud features. Default: None.
            img_metas (Sequence[Mapping]): Meta information, required for
                the projection in `point_sampling`.
            key_padding_mask: Forwarded to every layer. Default: None.
            attn_masks: Forwarded to every layer. Default: None.
            reg_branch: Accepted but not used by this method.
            **kwargs: Forwarded to every layer.

        Returns:
            Tensor: The output of the last layer, or the outputs of all
            layers stacked along a new leading dimension when
            `return_intermediate` is True.

        Raises:
            ValueError: If `self.pc_range` or `img_metas` is None.
        """
        # The projection does not depend on the layer, so compute it once.
        reference_points_cam, query_mask = self.point_sampling(
            reference_points, self.pc_range, img_metas)

        output = query
        intermediate = []
        for layer in self.layers:
            output = layer(
                output,
                query_pos,
                reference_points,
                reference_points_cam,
                query_mask,
                img_feat_flatten,
                pts_feat_flatten,
                img_spatial_shapes=img_spatial_shapes,
                img_level_start_index=img_level_start_index,
                pts_spatial_shapes=pts_spatial_shapes,
                pts_level_start_index=pts_level_start_index,
                img_metas=img_metas,
                key_padding_mask=key_padding_mask,
                attn_masks=attn_masks,
                **kwargs)
            if self.return_intermediate:
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output

@TRANSFORMER_LAYER.register_module()
class AFSALayer(MyCustomBaseTransformerLayer):
    """Decoder layer with self attention, image cross attention,
    point-cloud cross attention and an FFN.

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )):
            Configs for the attention modules, the order should be
            consistent with it in `operation_order`. If it is a dict, it
            would be expand to the number of attention in
            `operation_order`.
        feedforward_channels (int): The hidden dimension for FFNs.
        ffn_dropout (float): Probability of an element to be zeroed
            in ffn. Default: 0.0.
        operation_order (tuple[str]): The execution order of operation
            in the layer. It must have exactly 8 entries and use exactly
            the operations 'self_attn', 'img_cross_attn',
            'pts_cross_attn', 'ffn' and 'norm'.
        act_cfg (dict): The activation config for FFNs.
            Default: ``dict(type='ReLU', inplace=True)``.
        norm_cfg (dict): Config dict for normalization layer.
            Default: ``dict(type='LN')``.
        ffn_num_fcs (int): The number of fully-connected layers in FFNs.
            Default: 2.
    """

    def __init__(self,
                 attn_cfgs,
                 feedforward_channels,
                 ffn_dropout=0.0,
                 operation_order=None,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 ffn_num_fcs=2,
                 **kwargs):
        super(AFSALayer, self).__init__(
            attn_cfgs=attn_cfgs,
            feedforward_channels=feedforward_channels,
            ffn_dropout=ffn_dropout,
            operation_order=operation_order,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            ffn_num_fcs=ffn_num_fcs,
            **kwargs)
        self.fp16_enabled = False

        required_ops = {'self_attn', 'norm', 'img_cross_attn',
                        'pts_cross_attn', 'ffn'}
        assert operation_order is not None and len(operation_order) == 8, (
            f'{self.__class__.__name__} expects an operation_order with '
            f'exactly 8 entries, got {operation_order}')
        assert set(operation_order) == required_ops, (
            f'{self.__class__.__name__} expects operation_order to use '
            f'exactly the operations {sorted(required_ops)}, '
            f'got {operation_order}')

    def forward(self,
                query,
                query_pos,
                reference_points,
                reference_points_cam,
                query_mask,
                img_feat_flatten,
                pts_feat_flatten,
                img_spatial_shapes=None,
                img_level_start_index=None,
                pts_spatial_shapes=None,
                pts_level_start_index=None,
                img_metas=None,
                key_padding_mask=None,
                attn_masks=None,
                **kwargs):
        """Apply the operations listed in `operation_order` to `query`.

        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query.
            query_pos (Tensor): The positional encoding for `query`. It is
                used as both query and key position in self attention and
                is not passed to the cross attentions.
            reference_points (Tensor): The reference points of offset in
                point cloud features; passed to 'pts_cross_attn'.
            reference_points_cam (Tensor): The reference points of offset
                in image features; passed to 'img_cross_attn'.
            query_mask (Tensor): Mask passed to 'img_cross_attn' together
                with `reference_points_cam`.
            img_feat_flatten (Tensor): The image features, used as key and
                value of 'img_cross_attn'.
            pts_feat_flatten (Tensor): The point cloud features, used as
                key and value of 'pts_cross_attn'.
            img_spatial_shapes, img_level_start_index: Passed to
                'img_cross_attn' as `spatial_shapes` and
                `level_start_index`. Default: None.
            pts_spatial_shapes, pts_level_start_index: Passed to
                'pts_cross_attn' as `spatial_shapes` and
                `level_start_index`. Default: None.
            img_metas: Passed to both cross attentions. Default: None.
            key_padding_mask (Tensor): Passed to every attention.
                Default: None.
            attn_masks (List[Tensor] | Tensor | None): Masks indexed by
                attention position; only the self attention reads its
                entry. A single Tensor is copied for every attention. The
                length of a list should equal to the number of `attention`
                in `operation_order`. Default: None.

        Returns:
            Tensor: The query after all operations were applied.
        """
        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query

        if attn_masks is None:
            attn_masks = [None] * self.num_attn
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, (
                f'The length of '
                f'attn_masks {len(attn_masks)} must be equal '
                f'to the number of attention in '
                f'operation_order {self.num_attn}')

        for layer in self.operation_order:
            # self attention among the queries
            if layer == 'self_attn':
                query = self.attentions[attn_index](
                    query,
                    query,
                    query,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            # img cross attention
            elif layer == 'img_cross_attn':
                query = self.attentions[attn_index](
                    query,
                    img_feat_flatten,
                    img_feat_flatten,
                    identity if self.pre_norm else None,
                    query_pos=None,
                    reference_points_cam=reference_points_cam,
                    spatial_shapes=img_spatial_shapes,
                    level_start_index=img_level_start_index,
                    query_mask=query_mask,
                    key_padding_mask=key_padding_mask,
                    img_metas=img_metas,
                    **kwargs)
                attn_index += 1
                identity = query

            # point cloud cross attention
            elif layer == 'pts_cross_attn':
                query = self.attentions[attn_index](
                    query,
                    pts_feat_flatten,
                    pts_feat_flatten,
                    identity if self.pre_norm else None,
                    query_pos=None,
                    reference_points=reference_points,
                    spatial_shapes=pts_spatial_shapes,
                    level_start_index=pts_level_start_index,
                    key_padding_mask=key_padding_mask,
                    img_metas=img_metas,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query
